{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Copy of HW5 (NeuralNet) - Fine-Grained Malware Detection_charan",
      "provenance": [],
      "collapsed_sections": [
        "l0uviy0bMb0o"
      ],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "28JCZXo2qPQe",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow import feature_column\n",
        "from tensorflow.keras import layers\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.preprocessing import LabelEncoder,OneHotEncoder"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uV6KoDWB1E0U",
        "colab_type": "code",
        "outputId": "6a8b1ba6-3f20-4bca-eec3-fc721ddc673e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "from google.colab import drive \n",
        "drive.mount('/content/gdrive')"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WfFAIHKS2xxO",
        "colab_type": "code",
        "outputId": "a328cb5c-555e-4554-bfcf-10c231e2e39c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        }
      },
      "source": [
        "# Column headings taken from the kddcup.names file (41 features + label).\n",
        "col_names = [\"duration\",\"protocol_type\",\"service\",\"flag\",\"src_bytes\",\n",
        "    \"dst_bytes\",\"land\",\"wrong_fragment\",\"urgent\",\"hot\",\"num_failed_logins\",\n",
        "    \"logged_in\",\"num_compromised\",\"root_shell\",\"su_attempted\",\"num_root\",\n",
        "    \"num_file_creations\",\"num_shells\",\"num_access_files\",\"num_outbound_cmds\",\n",
        "    \"is_host_login\",\"is_guest_login\",\"count\",\"srv_count\",\"serror_rate\",\n",
        "    \"srv_serror_rate\",\"rerror_rate\",\"srv_rerror_rate\",\"same_srv_rate\",\n",
        "    \"diff_srv_rate\",\"srv_diff_host_rate\",\"dst_host_count\",\"dst_host_srv_count\",\n",
        "    \"dst_host_same_srv_rate\",\"dst_host_diff_srv_rate\",\"dst_host_same_src_port_rate\",\n",
        "    \"dst_host_srv_diff_host_rate\",\"dst_host_serror_rate\",\"dst_host_srv_serror_rate\",\n",
        "    \"dst_host_rerror_rate\",\"dst_host_srv_rerror_rate\",\"label\"]\n",
        "\n",
        "# Load the full KDD Cup dataset; the file has no header row, so supply names.\n",
        "# NOTE(review): hardcoded Colab Drive path -- adjust when running elsewhere.\n",
        "df = pd.read_csv(\"/content/gdrive/My Drive/data/kddcup.data\", header=None, names=col_names)\n",
        "\n",
        "# For faster iteration, swap in the 10% subset (same args, or columns break):\n",
        "# df = pd.read_csv('/content/gdrive/My Drive/data/kddcup.data_10_percent_corrected', header=None, names=col_names)\n",
        "df.head()"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>duration</th>\n",
              "      <th>protocol_type</th>\n",
              "      <th>service</th>\n",
              "      <th>flag</th>\n",
              "      <th>src_bytes</th>\n",
              "      <th>dst_bytes</th>\n",
              "      <th>land</th>\n",
              "      <th>wrong_fragment</th>\n",
              "      <th>urgent</th>\n",
              "      <th>hot</th>\n",
              "      <th>num_failed_logins</th>\n",
              "      <th>logged_in</th>\n",
              "      <th>num_compromised</th>\n",
              "      <th>root_shell</th>\n",
              "      <th>su_attempted</th>\n",
              "      <th>num_root</th>\n",
              "      <th>num_file_creations</th>\n",
              "      <th>num_shells</th>\n",
              "      <th>num_access_files</th>\n",
              "      <th>num_outbound_cmds</th>\n",
              "      <th>is_host_login</th>\n",
              "      <th>is_guest_login</th>\n",
              "      <th>count</th>\n",
              "      <th>srv_count</th>\n",
              "      <th>serror_rate</th>\n",
              "      <th>srv_serror_rate</th>\n",
              "      <th>rerror_rate</th>\n",
              "      <th>srv_rerror_rate</th>\n",
              "      <th>same_srv_rate</th>\n",
              "      <th>diff_srv_rate</th>\n",
              "      <th>srv_diff_host_rate</th>\n",
              "      <th>dst_host_count</th>\n",
              "      <th>dst_host_srv_count</th>\n",
              "      <th>dst_host_same_srv_rate</th>\n",
              "      <th>dst_host_diff_srv_rate</th>\n",
              "      <th>dst_host_same_src_port_rate</th>\n",
              "      <th>dst_host_srv_diff_host_rate</th>\n",
              "      <th>dst_host_serror_rate</th>\n",
              "      <th>dst_host_srv_serror_rate</th>\n",
              "      <th>dst_host_rerror_rate</th>\n",
              "      <th>dst_host_srv_rerror_rate</th>\n",
              "      <th>label</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>0</td>\n",
              "      <td>tcp</td>\n",
              "      <td>http</td>\n",
              "      <td>SF</td>\n",
              "      <td>215</td>\n",
              "      <td>45076</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.00</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>normal.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0</td>\n",
              "      <td>tcp</td>\n",
              "      <td>http</td>\n",
              "      <td>SF</td>\n",
              "      <td>162</td>\n",
              "      <td>4528</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>2</td>\n",
              "      <td>2</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.00</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>normal.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0</td>\n",
              "      <td>tcp</td>\n",
              "      <td>http</td>\n",
              "      <td>SF</td>\n",
              "      <td>236</td>\n",
              "      <td>1228</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>2</td>\n",
              "      <td>2</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.50</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>normal.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0</td>\n",
              "      <td>tcp</td>\n",
              "      <td>http</td>\n",
              "      <td>SF</td>\n",
              "      <td>233</td>\n",
              "      <td>2032</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>2</td>\n",
              "      <td>2</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>3</td>\n",
              "      <td>3</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.33</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>normal.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0</td>\n",
              "      <td>tcp</td>\n",
              "      <td>http</td>\n",
              "      <td>SF</td>\n",
              "      <td>239</td>\n",
              "      <td>486</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>3</td>\n",
              "      <td>3</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>4</td>\n",
              "      <td>4</td>\n",
              "      <td>1.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.25</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>normal.</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   duration protocol_type  ... dst_host_srv_rerror_rate    label\n",
              "0         0           tcp  ...                      0.0  normal.\n",
              "1         0           tcp  ...                      0.0  normal.\n",
              "2         0           tcp  ...                      0.0  normal.\n",
              "3         0           tcp  ...                      0.0  normal.\n",
              "4         0           tcp  ...                      0.0  normal.\n",
              "\n",
              "[5 rows x 42 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9yZoa8rvLuEa",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "## Data Preparation for training\n",
        "# ------------------------------\n",
        "def prepare_data(df):\n",
        "  '''Encode the raw KDD Cup frame for training. NOTE: mutates df in place.\n",
        "  \n",
        "  The string attack names in the 'label' column are collapsed into five\n",
        "  integer categories:\n",
        "  \n",
        "    0 - Normal connection\n",
        "    1 - dos attack\n",
        "    2 - probe attack\n",
        "    3 - r2l attack\n",
        "    4 - u2r attack\n",
        "  \n",
        "  The categorical feature columns (protocol_type, service, flag) are then\n",
        "  integer-encoded with sklearn's LabelEncoder. The single encoder object is\n",
        "  re-fit per column, so the fitted mappings are not kept for any later\n",
        "  inverse transform.\n",
        "  \n",
        "  Returns the same (mutated) DataFrame.\n",
        "  '''\n",
        "  # Collapse each attack name to its 5-way category code.\n",
        "  # NOTE(review): labels absent from this mapping pass through as strings and\n",
        "  # would make astype('int') below raise -- verify the mapping covers all labels.\n",
        "  newlabeldf=df['label'].replace({ 'normal.' : 0, 'neptune.' : 1 ,'back.': 1, 'land.': 1, 'pod.': 1, 'smurf.': 1, 'teardrop.': 1,'mailbomb.': 1, 'apache2.': 1, 'processtable.': 1, 'udpstorm.': 1, 'worm.': 1,\n",
        "                           'ipsweep.' : 2,'nmap.' : 2,'portsweep.' : 2,'satan.' : 2,'mscan.' : 2,'saint.' : 2\n",
        "                           ,'ftp_write.': 3,'guess_passwd.': 3,'imap.': 3,'multihop.': 3,'phf.': 3,'spy.': 3,'warezclient.': 3,'warezmaster.': 3,'sendmail.': 3,'named.': 3,'snmpgetattack.': 3,'snmpguess.': 3,'xlock.': 3,'xsnoop.': 3,'httptunnel.': 3,\n",
        "                           'buffer_overflow.': 4,'loadmodule.': 4,'perl.': 4,'rootkit.': 4,'ps.': 4,'sqlattack.': 4,'xterm.': 4})\n",
        "  df['label'] = newlabeldf.astype('int')\n",
        "  \n",
        "  # Integer-encode the three categorical feature columns (refit per column).\n",
        "  le = LabelEncoder()\n",
        "  df['protocol_type'] = le.fit_transform(df['protocol_type'])\n",
        "  df['service']= le.fit_transform(df['service'])\n",
        "  df['flag'] = le.fit_transform(df['flag'])\n",
        "  \n",
        "  return df\n",
        "  "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "grxlSUbpp7R1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# X = df.iloc[:,:41]\n",
        "# y = df.iloc[:,-1].astype('int')\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wwEJLsMBTHVM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "df = prepare_data(df)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e8QINkS4spWf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Stratified split on 'label': 25% test, then 1/3 of the remaining 75% as\n",
        "# validation -- net 50/25/25 train/val/test. random_state pins the shuffling\n",
        "# so the CSVs written below are reproducible across re-runs.\n",
        "df_train, df_test = train_test_split(df, stratify=df['label'], test_size=0.25, random_state=42)\n",
        "df_train, df_val = train_test_split(df_train, stratify=df_train['label'], test_size=0.3333, random_state=42)\n",
        "\n",
        "# Persist the splits so downstream cells (and later sessions) reuse them.\n",
        "df_train.to_csv(\"df_train.csv\",index=False)\n",
        "df_val.to_csv(\"df_val.csv\",index=False)\n",
        "df_test.to_csv(\"df_test.csv\",index=False)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bSm2Lm7qsrfT",
        "colab_type": "code",
        "outputId": "37bb7d91-6c5a-47c5-844a-77372d1bdcaa",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 119
        }
      },
      "source": [
        "df['label'].value_counts()"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1    3883370\n",
              "0     972781\n",
              "2      41102\n",
              "3       1126\n",
              "4         52\n",
              "Name: label, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_r_yurjVsiR4",
        "colab_type": "code",
        "outputId": "712a8c0c-6671-4249-c732-969de9f6c831",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 119
        }
      },
      "source": [
        "df_train['label'].value_counts()"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1    1941781\n",
              "0     486415\n",
              "2      20552\n",
              "3        563\n",
              "4         26\n",
              "Name: label, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l0uviy0bMb0o",
        "colab_type": "text"
      },
      "source": [
        "## Visualizing the categories (normal + four attack types)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WaoDh6iTL4f3",
        "colab_type": "code",
        "outputId": "7bb98647-1695-40cb-d963-1e92a624f194",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 226
        }
      },
      "source": [
        "from numpy import where\n",
        "from matplotlib import pyplot as plt\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "\n",
        "# Standardize the 41 feature columns (zero mean, unit variance) ahead of PCA.\n",
        "# NOTE(review): the label-encoded categorical columns (protocol_type, service,\n",
        "# flag) get scaled as if numeric -- confirm this is intended for the PCA view.\n",
        "X = df.iloc[:,:41]\n",
        "X = StandardScaler().fit_transform(X)\n",
        "y = df.iloc[:,-1].astype('int')\n",
        "\n",
        "# Preview the standardized features under their original column names.\n",
        "pd.DataFrame(data = X, columns = col_names[:-1]).head()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>duration</th>\n",
              "      <th>protocol_type</th>\n",
              "      <th>service</th>\n",
              "      <th>flag</th>\n",
              "      <th>src_bytes</th>\n",
              "      <th>dst_bytes</th>\n",
              "      <th>land</th>\n",
              "      <th>wrong_fragment</th>\n",
              "      <th>urgent</th>\n",
              "      <th>hot</th>\n",
              "      <th>num_failed_logins</th>\n",
              "      <th>logged_in</th>\n",
              "      <th>num_compromised</th>\n",
              "      <th>root_shell</th>\n",
              "      <th>su_attempted</th>\n",
              "      <th>num_root</th>\n",
              "      <th>num_file_creations</th>\n",
              "      <th>num_shells</th>\n",
              "      <th>num_access_files</th>\n",
              "      <th>num_outbound_cmds</th>\n",
              "      <th>is_host_login</th>\n",
              "      <th>is_guest_login</th>\n",
              "      <th>count</th>\n",
              "      <th>srv_count</th>\n",
              "      <th>serror_rate</th>\n",
              "      <th>srv_serror_rate</th>\n",
              "      <th>rerror_rate</th>\n",
              "      <th>srv_rerror_rate</th>\n",
              "      <th>same_srv_rate</th>\n",
              "      <th>diff_srv_rate</th>\n",
              "      <th>srv_diff_host_rate</th>\n",
              "      <th>dst_host_count</th>\n",
              "      <th>dst_host_srv_count</th>\n",
              "      <th>dst_host_same_srv_rate</th>\n",
              "      <th>dst_host_diff_srv_rate</th>\n",
              "      <th>dst_host_same_src_port_rate</th>\n",
              "      <th>dst_host_srv_diff_host_rate</th>\n",
              "      <th>dst_host_serror_rate</th>\n",
              "      <th>dst_host_srv_serror_rate</th>\n",
              "      <th>dst_host_rerror_rate</th>\n",
              "      <th>dst_host_srv_rerror_rate</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>-0.067792</td>\n",
              "      <td>0.925753</td>\n",
              "      <td>-0.104067</td>\n",
              "      <td>0.514274</td>\n",
              "      <td>-0.002879</td>\n",
              "      <td>0.138664</td>\n",
              "      <td>-0.006673</td>\n",
              "      <td>-0.04772</td>\n",
              "      <td>-0.002571</td>\n",
              "      <td>-0.044136</td>\n",
              "      <td>-0.009782</td>\n",
              "      <td>2.39698</td>\n",
              "      <td>-0.005679</td>\n",
              "      <td>-0.010552</td>\n",
              "      <td>-0.004676</td>\n",
              "      <td>-0.00564</td>\n",
              "      <td>-0.011232</td>\n",
              "      <td>-0.009919</td>\n",
              "      <td>-0.027632</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>-0.037263</td>\n",
              "      <td>-1.521417</td>\n",
              "      <td>-1.15664</td>\n",
              "      <td>-0.46409</td>\n",
              "      <td>-0.46352</td>\n",
              "      <td>-0.24796</td>\n",
              "      <td>-0.248631</td>\n",
              "      <td>0.536987</td>\n",
              "      <td>-0.255243</td>\n",
              "      <td>-0.203633</td>\n",
              "      <td>-3.451536</td>\n",
              "      <td>-1.694315</td>\n",
              "      <td>0.599396</td>\n",
              "      <td>-0.282867</td>\n",
              "      <td>-1.022077</td>\n",
              "      <td>-0.158629</td>\n",
              "      <td>-0.464418</td>\n",
              "      <td>-0.463202</td>\n",
              "      <td>-0.25204</td>\n",
              "      <td>-0.249464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>-0.067792</td>\n",
              "      <td>0.925753</td>\n",
              "      <td>-0.104067</td>\n",
              "      <td>0.514274</td>\n",
              "      <td>-0.002820</td>\n",
              "      <td>-0.011578</td>\n",
              "      <td>-0.006673</td>\n",
              "      <td>-0.04772</td>\n",
              "      <td>-0.002571</td>\n",
              "      <td>-0.044136</td>\n",
              "      <td>-0.009782</td>\n",
              "      <td>2.39698</td>\n",
              "      <td>-0.005679</td>\n",
              "      <td>-0.010552</td>\n",
              "      <td>-0.004676</td>\n",
              "      <td>-0.00564</td>\n",
              "      <td>-0.011232</td>\n",
              "      <td>-0.009919</td>\n",
              "      <td>-0.027632</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>-0.037263</td>\n",
              "      <td>-1.521417</td>\n",
              "      <td>-1.15664</td>\n",
              "      <td>-0.46409</td>\n",
              "      <td>-0.46352</td>\n",
              "      <td>-0.24796</td>\n",
              "      <td>-0.248631</td>\n",
              "      <td>0.536987</td>\n",
              "      <td>-0.255243</td>\n",
              "      <td>-0.203633</td>\n",
              "      <td>-3.297085</td>\n",
              "      <td>-1.600011</td>\n",
              "      <td>0.599396</td>\n",
              "      <td>-0.282867</td>\n",
              "      <td>-1.146737</td>\n",
              "      <td>-0.158629</td>\n",
              "      <td>-0.464418</td>\n",
              "      <td>-0.463202</td>\n",
              "      <td>-0.25204</td>\n",
              "      <td>-0.249464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>-0.067792</td>\n",
              "      <td>0.925753</td>\n",
              "      <td>-0.104067</td>\n",
              "      <td>0.514274</td>\n",
              "      <td>-0.002824</td>\n",
              "      <td>0.014179</td>\n",
              "      <td>-0.006673</td>\n",
              "      <td>-0.04772</td>\n",
              "      <td>-0.002571</td>\n",
              "      <td>-0.044136</td>\n",
              "      <td>-0.009782</td>\n",
              "      <td>2.39698</td>\n",
              "      <td>-0.005679</td>\n",
              "      <td>-0.010552</td>\n",
              "      <td>-0.004676</td>\n",
              "      <td>-0.00564</td>\n",
              "      <td>-0.011232</td>\n",
              "      <td>-0.009919</td>\n",
              "      <td>-0.027632</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>-0.037263</td>\n",
              "      <td>-1.521417</td>\n",
              "      <td>-1.15664</td>\n",
              "      <td>-0.46409</td>\n",
              "      <td>-0.46352</td>\n",
              "      <td>-0.24796</td>\n",
              "      <td>-0.248631</td>\n",
              "      <td>0.536987</td>\n",
              "      <td>-0.255243</td>\n",
              "      <td>-0.203633</td>\n",
              "      <td>-3.142633</td>\n",
              "      <td>-1.505707</td>\n",
              "      <td>0.599396</td>\n",
              "      <td>-0.282867</td>\n",
              "      <td>-1.188291</td>\n",
              "      <td>-0.158629</td>\n",
              "      <td>-0.464418</td>\n",
              "      <td>-0.463202</td>\n",
              "      <td>-0.25204</td>\n",
              "      <td>-0.249464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>-0.067792</td>\n",
              "      <td>0.925753</td>\n",
              "      <td>-0.104067</td>\n",
              "      <td>0.514274</td>\n",
              "      <td>-0.002840</td>\n",
              "      <td>0.014179</td>\n",
              "      <td>-0.006673</td>\n",
              "      <td>-0.04772</td>\n",
              "      <td>-0.002571</td>\n",
              "      <td>-0.044136</td>\n",
              "      <td>-0.009782</td>\n",
              "      <td>2.39698</td>\n",
              "      <td>-0.005679</td>\n",
              "      <td>-0.010552</td>\n",
              "      <td>-0.004676</td>\n",
              "      <td>-0.00564</td>\n",
              "      <td>-0.011232</td>\n",
              "      <td>-0.009919</td>\n",
              "      <td>-0.027632</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>-0.037263</td>\n",
              "      <td>-1.530800</td>\n",
              "      <td>-1.16476</td>\n",
              "      <td>-0.46409</td>\n",
              "      <td>-0.46352</td>\n",
              "      <td>-0.24796</td>\n",
              "      <td>-0.248631</td>\n",
              "      <td>0.536987</td>\n",
              "      <td>-0.255243</td>\n",
              "      <td>-0.203633</td>\n",
              "      <td>-2.988182</td>\n",
              "      <td>-1.411403</td>\n",
              "      <td>0.599396</td>\n",
              "      <td>-0.282867</td>\n",
              "      <td>-1.188291</td>\n",
              "      <td>-0.158629</td>\n",
              "      <td>-0.464418</td>\n",
              "      <td>-0.463202</td>\n",
              "      <td>-0.25204</td>\n",
              "      <td>-0.249464</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>-0.067792</td>\n",
              "      <td>0.925753</td>\n",
              "      <td>-0.104067</td>\n",
              "      <td>0.514274</td>\n",
              "      <td>-0.002842</td>\n",
              "      <td>0.035214</td>\n",
              "      <td>-0.006673</td>\n",
              "      <td>-0.04772</td>\n",
              "      <td>-0.002571</td>\n",
              "      <td>-0.044136</td>\n",
              "      <td>-0.009782</td>\n",
              "      <td>2.39698</td>\n",
              "      <td>-0.005679</td>\n",
              "      <td>-0.010552</td>\n",
              "      <td>-0.004676</td>\n",
              "      <td>-0.00564</td>\n",
              "      <td>-0.011232</td>\n",
              "      <td>-0.009919</td>\n",
              "      <td>-0.027632</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>-0.037263</td>\n",
              "      <td>-1.530800</td>\n",
              "      <td>-1.16476</td>\n",
              "      <td>-0.46409</td>\n",
              "      <td>-0.46352</td>\n",
              "      <td>-0.24796</td>\n",
              "      <td>-0.248631</td>\n",
              "      <td>0.536987</td>\n",
              "      <td>-0.255243</td>\n",
              "      <td>-0.203633</td>\n",
              "      <td>-2.833731</td>\n",
              "      <td>-1.317100</td>\n",
              "      <td>0.599396</td>\n",
              "      <td>-0.282867</td>\n",
              "      <td>-1.209067</td>\n",
              "      <td>-0.158629</td>\n",
              "      <td>-0.464418</td>\n",
              "      <td>-0.463202</td>\n",
              "      <td>-0.25204</td>\n",
              "      <td>-0.249464</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   duration  protocol_type  ...  dst_host_rerror_rate  dst_host_srv_rerror_rate\n",
              "0 -0.067792       0.925753  ...              -0.25204                 -0.249464\n",
              "1 -0.067792       0.925753  ...              -0.25204                 -0.249464\n",
              "2 -0.067792       0.925753  ...              -0.25204                 -0.249464\n",
              "3 -0.067792       0.925753  ...              -0.25204                 -0.249464\n",
              "4 -0.067792       0.925753  ...              -0.25204                 -0.249464\n",
              "\n",
              "[5 rows x 41 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M6URlefYs9Ji",
        "colab_type": "code",
        "outputId": "575c6ba1-58f5-4f04-fbb1-e0fe1e9dc912",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        }
      },
      "source": [
        "# Fit PCA on all 41 standardized features to see how much variance\n",
        "# each principal component explains.\n",
        "covar_matrix = PCA(n_components = 41)\n",
        "covar_matrix.fit(X)\n",
        "variance = covar_matrix.explained_variance_ratio_  # per-component variance ratios\n",
        "\n",
        "# Cumulative % of total variance explained by the first [n] components\n",
        "# (reuse `variance` instead of re-reading explained_variance_ratio_).\n",
        "var = np.cumsum(np.round(variance, decimals=3) * 100)\n",
        "var"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([26. , 37.9, 46.9, 54.5, 59.2, 63.3, 66.4, 69.3, 72.2, 74.8, 77.4,\n",
              "       79.9, 82.4, 84.8, 87. , 89.2, 91.2, 93.1, 94.9, 95.9, 96.8, 97.7,\n",
              "       98.3, 98.8, 99.2, 99.5, 99.6, 99.7, 99.8, 99.9, 99.9, 99.9, 99.9,\n",
              "       99.9, 99.9, 99.9, 99.9, 99.9, 99.9, 99.9, 99.9])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rIEsq9aCtWEl",
        "colab_type": "code",
        "outputId": "e4bb1fa6-2901-4db5-a41f-43a50273f3ab",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 312
        }
      },
      "source": [
        "# Scree-style plot of cumulative explained variance.\n",
        "# FIX: the original called plt.style.context('seaborn-whitegrid') as a bare\n",
        "# expression, which builds a context manager and immediately discards it --\n",
        "# a silent no-op. plt.style.use() actually applies the style.\n",
        "plt.style.use('seaborn-whitegrid')\n",
        "\n",
        "plt.ylabel('% Variance Explained')\n",
        "plt.xlabel('# of Features')\n",
        "plt.title('PCA Analysis')\n",
        "plt.ylim(20,100.5)\n",
        "\n",
        "plt.plot(var)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7fbfcc508cc0>]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xd8HOW1//HPkVzkLnfLRS7YYGzj\nhlxouQkdw8UUh5DQMTjccLmQhFBuSCFcSkghJL+E4IRiWiimGAjdGBNCYlvuFdxk4yI3We5F5fz+\nmBEIsZLXklaz0n7fr5deOzM7s8/xvKw9mueZOY+5OyIiIhWlRR2AiIgkJyUIERGJSQlCRERiUoIQ\nEZGYlCBERCQmJQgREYlJCUIkiZhZnpmdWsPP2G1mfWorJkldShBS74VfqvvCL8ZNZva4mbUs9/4Z\nZvahme0ysy1mNt3Mzq3wGV83MzezW+Nss7eZlZrZQ7X976kpd2/p7quijkPqPyUIaSj+091bAsOB\nHOAOADMbB7wAPAF0BzoDPwX+s8LxVwAFwOVxtnc5sB34lpk1rXH0IklICUIaFHdfD7wJDDIzA34L\n3OXuf3X3He5e6u7T3f3asmPMrAUwDrge6GdmOVW1EX7u5QRJqIgKySa8ErnOzJabWaGZ/TE8BjM7\nwszeN7NtZrbVzJ42s8wYbXQxs71m1r7ctuHhFVBjM+sbXgntCD/nuQrt9w2Xx5jZkvDqab2Z3XyY\np1RSmBKENChm1gMYA8wFjgJ6AJMPcdgFwG6CK423Ca4mqnIiwdXIs8Dzlex/DjACGAxcBJxRFiJw\nL9AVODqM7+cVD3b3fOCD8NgylwHPunsRcBfwDtA2jOUPlcT6CPBdd28FDALeP8S/TeRzShDSULxi\nZoXAR8B04B6g7K/vjYc49grgOXcvAZ4BLjazxofY/0133x7uf6aZdaqwz33uXujua4FpwFAAd1/h\n7u+6+wF330JwhfMflbQzCbgUwMzSgW8DT4bvFQE9ga7uvt/dP6rkM4qAAWbW2t23u/ucKv5dIl+i\nBCENxXnununuPd39e+6+D9gWvpdV2UHhFcc3gKfDTVOADODsSvZvBnyzbH93/xewFvhOhV3zyy3v\nBVqGx3c2s2fD7p6dwFNAh0rCm0Lw5d4bOA3Y4e4zw/duIbgamWlmi83s6ko+40KCK6o1YZfUcZXs\nJ/IVShDSkH0CfEbwJVmZywh+D14zs3xgFUGCqKyb6XygNfAnM8sPj+lWxf4V3QM4cIy7tya4QrBY\nO7r7foIurEvDOJ8s916+u1/r7l2B74bx9I3xGbPcfSzQCXgl/DyRuChBSIPlQS37HwA/MbOrzKy1\nmaWZ2YlmNjHc7QrgToIuoLKfC4Ex5QeIy7kCeBQ4ptz+JwBDzOyYOMJqRTDescPMugE/OsT+TwBX\nAudSLkGY2TfNrHu4up0g6ZSWP9DMmpjZJWbWJhy32FlxH5GqKEFIg+buk4FvAVcDG4BNwP8BU8xs\nNEE//h/Dv8jLfl4FVhD0+X8u/EI/Bfhdhf1nA28R31XEnQS34u4A/g68dIj4/0nwpT7H3deUe2sE\nMMPMdgOvAjdW8uzDZUBe2J11HXBJHDGKAGCaMEgkuZnZ+8Az7v7XqGOR1KIEIZLEzGwE8C7Qw913\nRR2PpBZ1MYkkKTObBLwH3KTkIFHQFYSIiMSUsCsIM3vUzDab2aJy29qZ2bthCYJ3zaxtuN3M7Pdm\ntsLMFpjZ8ETFJSIi8UnYFYSZfY3gdr4n3H1QuO1+oMDd7zOz24C27n6rmY0BbiB4oGcU8KC7jzpU\nGx06dPBevXolJH4RkYZq9uzZW92946H2a5SoANz9QzPrVWHzWODr4fIkglozt4bbnwjvW/+3mWWa\nWZa7V1kioVevXuTm5tZm2CIp
r6TU2V9UEvwUl36xXFTC/qLSL78Wf7F8oML++4q+WD5QVEqJurNr\n1bUn9eHMQV2qdayZrTn0XglMEJXoXO5LP5+g9DIET6J+Vm6/deG2ryQIM5sATADIzs5OXKQi9diu\n/UVs2rmf/B0H2LhjX7C8cz+bdx5gX7kv+/LLB8Iv/KKS6n+RZzROI6NxOhmN0r9YbhwsN0lLr8V/\noTROj/kAfq2q6wTxOXd3Mzvs/4nuPhGYCJCTk6M/SSQllZY6G3bsY822veRt28OabXtZvXUPa7bt\nYUPhfnYfKP7KMW2bN6ZTqwyaNw2+wDu0bESzJsFy08bplXy5B69NY3zhB/ul06xxOk0bp9G0URph\nVXNpIOo6QWwq6zoysyxgc7h9PUHZ4zLdw20iKa201Pls+16WbtzFJ/m7WJa/k+Wbd7O2YC8Hi7+o\nmtG0URo92zenZ/sWHH9EB7LaZNClTQZdWgevnVtnkNFYf8HL4anrBPEqQTmC+8LXKeW2/7eZPUsw\nSL3jUOMPIg3NgeISlm7cxfzPClm6cSfL8nfx6aZd7D1YAoAZ9GzXnCM7t+KU/p3o1aEFPds3p1f7\nFnRpnUFamv56l9qVsARhZn8jGJDuYGbrgJ8RJIbnzWw8sIYvJkN5g+AOphUEpZGvSlRcIsnA3Vm9\ndQ/z1xUyb20h89btYOmGnRwsCa4K2jZvTP8urbkopwdHZ7XiqC6tObJzS5o3iaxXWFJQIu9i+nYl\nb50SY18nmO5RpMHasbeIqcs28fbifP69qoAd+4oAaN4kncHd23DVib0Y2j2TIT0yyWqTof58iZz+\nHBFJoM279vPukk28tSiff63cRnGp06V1BmcN6sKw7CAZ9OvUinR1D0kSUoIQqWWfFezl7cX5vL04\nn9w123GHXu2bM/6k3pw5sAtDumdqvEDqBSUIkRpyd1Zs3s1bi/J5a3E+izfsBODorNbcdMqRnDmo\nC0d2bqkuI6l3lCBEqsHdWbBuB28tzuftRfms2roHgOHZmfzvmP6cMbALPdu3iDhKkZpRghA5DMvy\nd/LK3A28Nn8D6wv3kZ5mHNenPVed2JvTB3Smc+uMqEMUqTVKECKHsL5wH6/O28CUeetZlr+L9DTj\npH4d+P5pR3Lq0Z3IbN4k6hBFEkIJQiSGPQeKmTJvA6/MW8/M1QVA0H1057kDOXtwFh1aNo04QpHE\nU4IQqWDaJ5u54+VFrC/cxxEdW/DD045k7NBuZLdvHnVoInVKCUIktHX3Ae56fQlT5m2gb6eWPDth\nNKN6t9PdR5KylCAk5bk7L81Zz11/X8KeA8XcdGo//uvrR9C0kYrbSWpTgpCUtnbbXn78ykL+sXwr\nx/Zsy30XHEO/zq2iDkskKShBSEoqKXUe/Wg1v3n3ExqlpXHX2IFcMqqnnnAWKUcJQlLOqi27+eEL\n85m7tpBTj+7EXecNIqtNs6jDEkk6ShCSMkpLncc+zuP+t5aR0TidBy8eyrlDumoQWqQSShCSEtZs\n28OPXljAzLwCTunfiXsvOIZOeupZpEpKENKglZY6T89Yw71vLiPdjF+NG8y4Y7vrqkEkDkoQ0mCt\nL9zHLZPn888V2zipXwd+eeFgumZqrEEkXkoQ0iBNmbeeO15eRKk795x/DN8e2UNXDSKHKZIEYWY3\nAtcCBvzF3X9nZu2A54BeQB5wkbtvjyI+qb927i/ip68s4pV5Gzi2Z1t+962h9GinEhki1ZFW1w2a\n2SCC5DASGAKcY2Z9gduAqe7eD5garovEbVZeAWf97h+8tmAjPzjtSJ6bMFrJQaQGoriCOBqY4e57\nAcxsOnABMBb4erjPJOAD4NYI4pN6pqiklN9PXc4fp62ge9vmvHDdcQzPbht1WCL1XhQJYhFwt5m1\nB/YBY4BcoLO7bwz3yQc6xzrYzCYAEwCys7MTH60ktbyte7jxuXnM/6yQccd25+fnDqRlUw2tid
SG\nOv9NcvelZvZL4B1gDzAPKKmwj5uZV3L8RGAiQE5OTsx9JDVMmbee219aSKM044/fGc7Zg7OiDkmk\nQYnkTy13fwR4BMDM7gHWAZvMLMvdN5pZFrA5itgk+R0sLuWeN5by+Md5jOjVlgcvHqbbV0USIKq7\nmDq5+2YzyyYYfxgN9AauAO4LX6dEEZskt00793P903PIXbOdq0/oze1j+tM4vc7vtRBJCVF11r4Y\njkEUAde7e6GZ3Qc8b2bjgTXARRHFJklqxqptXP/MXPYcKOb33x7GuUO6Rh2SSIMWVRfTSTG2bQNO\niSAcSXLuziMfrebeN5fRs11znrl2FEdqzgaRhNPtHpLU9hwo5pYXF/D3BRs5Y2Bnfv3NIbTKaBx1\nWCIpQQlCktbKLbu57snZrNyym1vP7M91/9FH5TJE6pAShCSltxZt5OYXFtCkURpPjh/FCX07RB2S\nSMpRgpCkUlxSyq/e+YSHp69iSI9MHrpkuG5hFYmIEoQkja27D3DDM3P516ptXDo6m5+cM4CmjdKj\nDkskZSlBSFKYs3Y733tqDtv3HuTX3xzCuGO7Rx2SSMpTgpBIuTtPzVjLL15bTJc2Gbz0veMZ2LVN\n1GGJCEoQEqGDxaX85JVFPJf7GSf378QDFw2lTXPdwiqSLJQgJBI79hZx3VOz+deqbdxwcl++f+qR\npKXpFlaRZKIEIXUub+serp40i3UF+3jgW0M4f5jGG0SSkRKE1KmZqwv47pO5ADx1zShG9m4XcUQi\nUhklCKkzL89dx62TF9K9bTMevXIEvTq0iDokEamCEoQknLvzwHvL+f3U5RzXpz0PXTqczOZNog5L\nRA5BCUISan9RCbdMXsCr8zfwzWO7c/f5x9CkkeZvEKkPlCAkYXbsLeLaJ3OZubqAH51xFN/7+hEq\ntidSjyhBSEJsKNzHlY/NZPXWPTx48VDGDu0WdUgicpiUIKTWLd24kysfm8neAyVMumokx6sSq0i9\npAQhterjlVv57hOzadG0Ec9fdxxHZ7WOOiQRqaZIRgvN7PtmttjMFpnZ38wsw8x6m9kMM1thZs+Z\nmW5zqWemzFvPFY/OJCszqKmk5CBSv9V5gjCzbsD/ADnuPghIBy4Gfgk84O59ge3A+LqOTarH3Zn4\n4UpufHYew7Pb8sJ3j9ccDiINQFT3GzYCmplZI6A5sBE4GZgcvj8JOC+i2OQwlJQ6d762hHveWMbZ\ng7N4YvxIFdwTaSDqfAzC3deb2a+BtcA+4B1gNlDo7sXhbuuAmLe9mNkEYAJAdnZ24gOWSu0vKuH7\nz83jzUX5jD+xNz8ec7QK7ok0IJUmCDPbBXhl77t7tTqYzawtMBboDRQCLwBnxnu8u08EJgLk5ORU\nGp8kVuHeg1z7RC6z8rZzx9lHc81JfaIOSURqWaUJwt1bAZjZXQRdQE8CBlwCZNWgzVOB1e6+Jfz8\nl4ATgEwzaxReRXQH1tegDUmgddv3cuVjs1i7bS//7zvDOGdw16hDEpEEiGcM4lx3/5O773L3ne7+\nEMEVQHWtBUabWXMLHqs9BVgCTAPGhftcAUypQRuSIIs37OCCP33Mpp37eWL8SCUHkQYsngSxx8wu\nMbN0M0szs0uAPdVt0N1nEAxGzwEWhjFMBG4FfmBmK4D2wCPVbUMS46PlW/nWw/8mPc2YfN3xjO7T\nPuqQRCSB4hmk/g7wYPjjwD/DbdXm7j8DflZh8ypgZE0+VxLn5bnr+NELC+jbqSWPXTWCrDa6jVWk\noTtkgnD3PGrWpST13MQPV3LPG8s4rk97Hr78WFpn6DZWkVRwyC4mMzvSzKaa2aJwfbCZ3ZH40CRq\n7s69by79/BmHx68eoeQgkkLiGYP4C3A7UATg7gsInnyWBqyk1Ln9pYU8PH0Vl4zK5vcXD6Npo/So\nwxKROhTPGERzd59ZoY5/cWU7S/23v6iEm56dx1uL87nh5L
784LQjNY+DSAqKJ0FsNbMjCB+aM7Nx\nBM9FSAO0+0AxE57I5eOV2/jJOQMYf2LvqEMSkYjEkyCuJ7gNtb+ZrQdWA5cmNCqJRMGeg1z52EwW\nb9jJb745hAuP7R51SCISoXjuYloFnGpmLYA0d9+V+LCkrm0o3Mdlj8xg3fZ9PHzpsZw6oHPUIYlI\nxA6ZIMysKXAh0AtoVNYX7e6/SGhkUmdWbN7N5Y/MYNf+Yp64eiSj9ACciBBfF9MUYAdBxdUDiQ1H\n6tqCdYVc+dgs0gz+NmE0g7q1iTokEUkS8SSI7u4ed7VVqT8+XrmVayflktm8CU9dM4reHVpEHZKI\nJJF4noP42MyOSXgkUqfeWpTPlY/OolvbZrz4X8crOYjIV8RzBXEicKWZrSboYjLA3X1wQiOThHl+\n1mfc9tIChvTI5LErR5DZXNN/i8hXxZMgzkp4FFJnHp6+knvfXMZJ/Trw8GXH0rxJnU8qKCL1RFUz\nyrV2952AbmttANyd+95axsPTV3H24CweuGgoTRpFNSW5iNQHVf35+AxwDsHdS07QtVTGAc0xWU+4\nOz+dspgn/72G74zK5q6xg0jX3NEicghVTTl6TviqWgv1mLvzi9eX8OS/1/Ddr/XhtrP6q66SiMQl\nrg5oM2sL9AMyyra5+4eJCkpqR1m30mP/zOOqE3opOYjIYYnnSeprgBuB7sA8YDTwL+DkxIYmNfXA\ne8t5ePoqLh2dzU/PGaDkICKHJZ5RyhuBEcAad/8GMAworG6DZnaUmc0r97PTzG4ys3Zm9q6ZLQ9f\n21a3DYE/TlvB76cu56Kc7vzi3EFKDiJy2OJJEPvdfT8EdZncfRlwVHUbdPdP3H2ouw8FjgX2Ai8D\ntwFT3b0fMDVcl2r4y4er+NXbn3D+sG7ce8Fg0jQgLSLVEE+CWGdmmcArwLtmNgVYU0vtnwKsdPc1\nBPNeTwq3TwLOq6U2Usqkj/O4+42lnH1MFr8aN1h3K4lItcVT7vv8cPHnZjYNaAO8VUvtXwz8LVzu\n7O5lExHlAzHrTZvZBGACQHZ2di2F0TD8beZafvbqYk4f0JnfXTyURul6zkFEqs/cPfYbZu2qOtDd\nC2rUsFkTYAMw0N03mVmhu2eWe3+7u1c5DpGTk+O5ubk1CaPBeHH2Om6ePJ+vH9mRP192rOaPFpFK\nmdlsd8851H5VXUHEekCuTG08KHcWMMfdN4Xrm8wsy903mlkWsLmGn58yXp2/gR9Nns8JR3TgoUuV\nHESkdlT1oFyiH5D7Nl90LwG8ClwB3Be+Tklw+w3Cmws38v3n5jGiVzv+cnkOGY2VHESkdsT7oNwF\nBFVdHfiHu79Sk0bD6UtPA75bbvN9wPNmNp5gEPyimrSRCt5bsokb/jaXoT0yefTKETRrouQgIrUn\nngfl/gT05Yu/9q8zs9Pc/frqNurue4D2FbZtI7irSeIw/dMtfO/pOQzs2prHrhpBi6aqyioitSue\nb5WTgaM9HM02s0nA4oRGJVX6eMVWJjyRS99OLXni6lG0zmgcdUgi0gDFcx/kCqD8/aQ9wm0SgVl5\nBYyflEvP9s156ppRtGmu5CAiiRHPFUQrYKmZzSQYgxgJ5JrZqwDufm4C45Ny5q7dzlWPzSIrM4On\nrxlNuxaaCU5EEieeBPHThEchh/Tppl1c+dgs2rdswjPXjKZjq6ZRhyQiDVw8CWKLuy8pv8HMvu7u\nHyQmJKlofeE+Ln9kJk0bpfHU+FF0aZNx6INERGoonjGI583sFgs0M7M/APcmOjAJFOw5yGWPzGDP\nwWImXT2SHu2aRx2SiKSIeBLEKIJB6o+BWQTlMU5IZFAS2HuwmKsfn8W67fv46+U5HJ3VOuqQRCSF\nxJMgioB9QDOCGeVWu3tpQqMSikpK+a+n5rBgXSF/+PYwRvVpf+iDRERqUTwJYhZBghgBnAR828xe\nSGhUKa601Lll8gKmf7
qFe84/hjMGdok6JBFJQfEMUo9397KSqRuBsWZ2WQJjSmnuzj1vLOXlueu5\n+fQjuXikSpqLSDQqvYIws5MB3D3XzCoW7tuT0KhS2MQPV/HXj1Zz5fG9uP4bfaMOR0RSWFVdTL8u\nt/xihffuSEAsKe+NhRu5981lnDM4i5+eM0DzSItIpKpKEFbJcqx1qaFNO/dz+0sLGdIjk99cNETz\nSItI5KpKEF7Jcqx1qQH3YFD6QHEJv71oiCb8EZGkUNUgdZ+w3pKVWyZcT/RkQinlmZlrmf7pFu48\ndyBHdGwZdTgiIkDVCWJsueVfV3iv4rpU05pte7j770s5sW8HLhvdM+pwREQ+V9WUo9PrMpBUVFLq\n/PD5+aSnGfePG6xxBxFJKpqGLEJ/+ccqctds54FvDaFrZrOowxER+ZJ4nqSudWaWaWaTzWyZmS01\ns+PMrJ2ZvWtmy8PXtlHEVleWbtzJb9/5lLMGdeG8od2iDkdE5CviThBmVptlRB8E3nL3/sAQYClw\nGzDV3fsBU8P1BulAcQnff24erZs15v/OG6TnHUQkKR0yQZjZ8Wa2BFgWrg8xsz9Vt0EzawN8DXgE\nwN0PunshwaD4pHC3ScB51W0j2T343nKW5e/ivguOoX1LTfwjIskpniuIB4AzgG0A7j6f4Au+unoD\nW4DHzGyumf3VzFoAnd19Y7hPPtA51sFmNsHMcs0sd8uWLTUIIxqz1xTw5+kr+VZOD04dEPOfKCKS\nFOLqYnL3zypsKqlBm42A4cBD7j6MoK7Tl7qT3N2p5GE8d5/o7jnuntOxY8cahFH39heV8MPn59M1\nsxl3nHN01OGIiFQpngTxmZkdD7iZNTazmwnGDKprHbDO3WeE65MJEsYmM8sCCF8316CNpPTn6SvJ\n27aX+y8cTKuMxlGHIyJSpXgSxHXA9UA3YD0wNFyvFnfPJ0g6R4WbTgGWAK8CV4TbrgCmVLeNZLR2\n217+9MFK/nNIV47v2yHqcEREDumQz0G4+1bgklpu9wbgaTNrAqwCriJIVs+b2XhgDXBRLbcZqV+8\nvpjGacaPx6hrSUTqh0MmCDObBNwY3mlE+HzCb9z96uo26u7zgJwYb51S3c9MZlOXbuK9pZv53zH9\n6dImI+pwRETiEk8X0+Cy5ADg7tuBYYkLqWHZX1TCz19bTN9OLbnqBNU4FJH6I54EkVb+qWYza4dK\ndMTtz9NX8lnBPn5x7kAap0fy4LqISLXE80X/G+BfZvYCQanvccDdCY2qgVi7bS8PfbCScwZnaWBa\nROqdeAapnzCz2cA3wk0XuPuSxIbVMPzi9cWkpxk/PlsD0yJS/8TbVbQM2F62v5llu/vahEXVALy/\nLBiYvv2s/mS1UaVWEal/4rmL6QbgZ8AmgieojeAp58GJDa3+2l9Uws9fXcIRHVtoYFpE6q14riBu\nBI5y922JDqaheHj6KtYW7OXpa0bRpJEGpkWkfoqr1AawI9GBNBSfFezlTx+s4OzBWZyggWkRqcfi\nuYJYBXxgZn8HDpRtdPffJiyqeuyB9z4lzYw7NDAtIvVcPAlibfjTJPyRSuzaX8QbCzdywfDuGpgW\nkXovnttc76yLQBqCNxfms7+olAuHd486FBGRGovnLqaOwC3AQODzQkLufnIC46qXJs9ZR+8OLRie\nnRl1KCIiNRbPIPXTBM9B9AbuBPKAWQmMqV76rGAvM1cXcOHwbppjWkQahHgSRHt3fwQocvfpYRVX\nXT1U8OKcdZjB+epeEpEGIp5B6qLwdaOZnQ1sANolLqT6x915ac56juvTnm6ZGpwWkYYhngTxf2bW\nBvgh8AegNfD9hEZVz+Su2c7agr3ceEq/qEMREak18dzF9Hq4uIMvCvZJOS/OXkfzJumcOahL1KGI\niNSaShOEmd3i7veb2R8Iai99ibv/T0Ijqyf2F5Xw9wUbOWtQFi2aapoMEWk4qvpGWxq+
5tZ2o2aW\nB+wiKP5X7O454UREzwG9CO6UuiicvS6pvb04n10HirlweLeoQxERqVWVJgh3f83M0oFj3P3mBLT9\nDXffWm79NmCqu99nZreF67cmoN1a9eKc9XTLbMboPu2jDkVEpFZVeZuru5cAJ9RRLGOBSeHyJOC8\nOmq32jbt3M9Hy7dw/rBupKXp2QcRaVji6TSfZ2avAi8Ae8o2uvtLNWjXgXfMzIGH3X0i0NndN4bv\n5wOdYx1oZhOACQDZ2dk1CKHmXp67nlKHC9S9JCINUDwJIgPYxpcfjnOgJgniRHdfb2adgHfNbFn5\nN93dw+TxFWEymQiQk5MTc5+64O68OHsdw7Mz6dOxZVRhiIgkTDy3uV5V2426+/rwdbOZvQyMBDaZ\nWZa7bzSzLGBzbbdbmxau38Hyzbu5+/xBUYciIpIQ8RTrywDG89VifVdXp0EzawGkufuucPl04BfA\nq8AVwH3h65TqfH5deXH2Opo0SuOcwV2jDkVEJCHiqcX0JNAFOAOYDnQnuEW1ujoDH5nZfGAm8Hd3\nf4sgMZxmZsuBU8P1pHSwuJRX52/gtAGdadOscdThiIgkRDxjEH3d/ZtmNtbdJ5nZM8A/qtugu68C\nhsTYvg04pbqfW5emfbKZ7XuLGKfCfCLSgMVzBVFWrK/QzAYBbYBOiQsp+b04ex0dWjblpH6ac1pE\nGq54EsREM2sL3EEwTrAE+GVCo0piBXsOMu2TzZw3tCuN0uM5fSIi9VNVtZi6uHu+u/813PQh0Kdu\nwkpe7yzOp6jEOW+Ynn0QkYatqj+B55nZe2Y23sw0h2bonSWb6N62GQO7to46FBGRhKoqQXQDfgWc\nCHxiZlPM7GIzS9kZcXYfKOajFVs5fUAXTSsqIg1epQnC3Uvc/e3wQbkewKME9ZJWm9nTdRVgMvnw\n0y0cLC7l9IExq4CIiDQocY2yuvtBgsHppcBO4OhEBpWs3lmcT9vmjcnp2TbqUEREEq7KBGFmPczs\nR2Y2B3g93P9cdx9eJ9ElkaKSUqYu28ypR3fW3UsikhKquovpY4JxiOeBa919dp1FlYRmrCpg1/5i\nTh+oaUVFJDVU9ST1bcA/3D2yiqnJ5O3F+TRrnK6H40QkZVQ1o9yHdRlIMistdd5dsomvHdmBjMbp\nUYcjIlIn1Jkeh4Xrd5C/cz+nD1D3koikDiWIOLyzJJ/0NOOUo1O6BJWIpJi4E4SZjTazt8zsAzNL\n+vmia9M7izcxqnc7Mps3iToUEZE6U2mCMLOK/Sk/AM4HxgB3JTKoZLJqy26Wb97N6QP0cJyIpJaq\n7mL6c/j8w/3uvh8oBMYBpQQPy6WEd5dsAuA03d4qIimmqlIb5wFzgdfN7HLgJqAp0B5ImS6md5Zs\nYlC31nTLTNkSVCKSoqocg3BLiwlmAAAM1ElEQVT31wimGm0DvAx86u6/d/ctdRFc1Dbv2s+ctdt1\n95KIpKSqxiDONbNpwFvAIuBbwFgze9bMjqhpw2aWbmZzzez1cL23mc0wsxVm9pyZRT4i/N6Szbij\n4nwikpKquoL4P+As4CLgl+5e6O4/BH4C3F0Lbd9IUPyvzC+BB9y9L7AdGF8LbdTIO0vyyW7XnKM6\nt4o6FBGROldVgtgBXABcCGwu2+juy9394po0ambdgbOBv4brBpwMTA53mUTE4xy79hfx8YptnDGw\ns+Z+EJGUVFWCOJ9gQLoR8J1abvd3wC0Ed0QRtlPo7sXh+jqCQoFfYWYTzCzXzHK3bEncUMj0T7dw\nsKRUxflEJGVVdRfTVnf/g7v/2d1r7bZWMzsH2Fzd6rDuPtHdc9w9p2PHjrUV1le8s3gT7Vs0YXi2\n5n4QkdRU1XMQiXICcK6ZjQEygNbAg0CmmTUKryK6A+sjiA2Ag8WlTFu2mTHHZJGepu4lEUlNdV6L\nyd1vd/fu7t4LuBh4390vAaYRPIgHcAUwpa5jK/Pv
VdvYdaBYdy+JSEpLpmJ9twI/MLMVBGMSj0QV\nyDtL8mneJJ0T+mruBxFJXVF0MX3O3T8APgiXVwEjo4ynzLRlW/hav46a+0FEUloyXUEkhXXb97K+\ncB/HHdE+6lBERCKlBFFBbt52AHJ66e4lEUltShAVzMwroFXTRvTv0jrqUEREIqUEUUFuXgHH9mqr\n21tFJOUpQZSzfc9BPt20mxG92kUdiohI5JQgysldE4w/KEGIiChBfEluXgFN0tMY3L1N1KGIiERO\nCaKcmXkFDO7eRs8/iIigBPG5fQdLWLR+BznqXhIRAZQgPjfvs0KKSpyRvfX8g4gIKEF8blZeAWZw\nbLauIEREQAnic7PyCjiqcyvaNG8cdSgiIklBCQIoLillzprtur1VRKQcJQhg6cZd7DlYovpLIiLl\nKEEQdC8BjOytKwgRkTJKEAQJonvbZmS1aRZ1KCIiSSPlE4S7MytP4w8iIhWlfILI27aXrbsPKEGI\niFRQ5wnCzDLMbKaZzTezxWZ2Z7i9t5nNMLMVZvacmTWpi3hmrQ7GH0ZogFpE5EuiuII4AJzs7kOA\nocCZZjYa+CXwgLv3BbYD4+simFl5BbRt3pi+nVrWRXMiIvVGnScID+wOVxuHPw6cDEwOt08CzquL\neGblFXBsz3aYaYIgEZHyIhmDMLN0M5sHbAbeBVYChe5eHO6yDuhWybETzCzXzHK3bNlSozg279pP\n3ra9qr8kIhJDJAnC3UvcfSjQHRgJ9D+MYye6e46753Ts2LFGceTmaYIgEZHKRHoXk7sXAtOA44BM\nM2sUvtUdWJ/o9meuLiCjcRoDu2qCIBGRiqK4i6mjmWWGy82A04ClBIliXLjbFcCURMeSu6aAYT3a\n0qRRyt/tKyLyFVF8M2YB08xsATALeNfdXwduBX5gZiuA9sAjiQxi1/4ilmzYqdtbRUQq0ejQu9Qu\nd18ADIuxfRXBeESdmLu2kFKHEaq/JCISU8r2rczKKyA9zRiWrSsIEZFYUjZBzFxdwICs1rRsWucX\nUSIi9UJKJoiDxaXM+6xQt7eKiFQhJRPEwvU7OFBcqgFqEZEqpGSCKJsgKEdXECIilUrJDvgxg7Lo\n2LIpHVs1jToUEZGklZIJIrt9c7LbN486DBGRpJaSXUwiInJoShAiIhKTEoSIiMSkBCEiIjEpQYiI\nSExKECIiEpMShIiIxKQEISIiMSlBiIhITEoQIiISkxKEiIjEVOcJwsx6mNk0M1tiZovN7MZwezsz\ne9fMloevqsUtIhKhKK4gioEfuvsAYDRwvZkNAG4Dprp7P2BquC4iIhGp8wTh7hvdfU64vAtYCnQD\nxgKTwt0mAefVdWwiIvKFSMt9m1kvYBgwA+js7hvDt/KBzpUcMwGYEK7uNrNPqtl8B2BrNY9NJMV1\neBTX4UvW2BTX4alJXD3j2cncvZqfXzNm1hKYDtzt7i+ZWaG7Z5Z7f7u7J2wcwsxy3T0nUZ9fXYrr\n8Ciuw5essSmuw1MXcUVyF5OZNQZeBJ5295fCzZvMLCt8PwvYHEVsIiISiOIuJgMeAZa6+2/LvfUq\ncEW4fAUwpa5jExGRL0QxBnECcBmw0Mzmhdv+F7gPeN7MxgNrgIsSHMfEBH9+dSmuw6O4Dl+yxqa4\nDk/C44psDEJERJKbnqQWEZGYlCBERCSmlEwQZnammX1iZivMLGme2DazPDNbaGbzzCw3wjgeNbPN\nZrao3LbIS6FUEtfPzWx9eM7mmdmYCOJKyvIxVcQV6Tkzswwzm2lm88O47gy39zazGeHv5XNm1iRJ\n4nrczFaXO19D6zKucvGlm9lcM3s9XE/8+XL3lPoB0oGVQB+gCTAfGBB1XGFseUCHJIjja8BwYFG5\nbfcDt4XLtwG/TJK4fg7cHPH5ygKGh8utgE+BAVGfsyriivScAQa0DJcbEzwoOxp4Hrg43P5n4L+S\nJK7HgXFR/h8L
Y/oB8Azwerie8POVilcQI4EV7r7K3Q8CzxKU+ZCQu38IFFTYHHkplEriipwnafmY\nKuKKlAd2h6uNwx8HTgYmh9ujOF+VxRU5M+sOnA38NVw36uB8pWKC6AZ8Vm59HUnwSxNy4B0zmx2W\nFEkmcZVCich/m9mCsAsq0irA1SkfUxcqxAURn7Owu2QewQOx7xJc1Re6e3G4SyS/lxXjcvey83V3\neL4eMLOmdR0X8DvgFqA0XG9PHZyvVEwQyexEdx8OnEVQ5fZrUQcUiwfXtEnxlxXwEHAEMBTYCPwm\nqkDC8jEvAje5+87y70V5zmLEFfk5c/cSdx8KdCe4qu9f1zHEUjEuMxsE3E4Q3wigHXBrXcZkZucA\nm919dl22C6mZINYDPcqtdw+3Rc7d14evm4GXCX5xkkVSlkJx903hL3Up8BciOmfJWj4mVlzJcs7C\nWAqBacBxQKaZlT28G+nvZbm4zgy76tzdDwCPUffn6wTgXDPLI+gSPxl4kDo4X6mYIGYB/cI7AJoA\nFxOU+YiUmbUws1Zly8DpwKKqj6pTSVkKpewLOHQ+EZyzZC0fU1lcUZ8zM+toZpnhcjPgNILxkWnA\nuHC3KM5XrLiWlUvyRtDPX6fny91vd/fu7t6L4PvqfXe/hLo4X1GPzEfxA4whuKNjJfDjqOMJY+pD\ncEfVfGBxlHEBfyPoeigi6NscT9DnORVYDrwHtEuSuJ4EFgILCL6QsyKI60SC7qMFwLzwZ0zU56yK\nuCI9Z8BgYG7Y/iLgp+H2PsBMYAXwAtA0SeJ6Pzxfi4CnCO90iuIH+Dpf3MWU8POlUhsiIhJTKnYx\niYhIHJQgREQkJiUIERGJSQlCRERiUoIQEZGYlCCkwTKze83sG2Z2npndfpjHdgwrZc41s5MqvPeB\nBdWAy6p7jqvscw7Rxk1m1rw6x4rUBSUIachGAf8G/gP48DCPPQVY6O7D3P0fMd6/xN2Hhj+TY7wf\nj5uAw0oQ5Z6cFUk4JQhpcMzsV2a2gKB2zr+Aa4CHzOynMfbtZWbvh4XYpppZdljv/35gbHiF0CzO\ndi8N5xOYZ2YPm1l6uP0hM8utMMfA/wBdgWlmNi3ctrvcZ40zs8fD5cfN7M9mNgO4P3zq/tGwrblm\nNjbcb2C59heYWb/qnkMR0JzU0kCZ2QjgcoIa+h+4+wmV7PcaMNndJ5nZ1cC57n6emV0J5Lj7f8c4\n5gOCuRb2hZtOAToRJJUL3L3IzP4E/NvdnzCzdu5eECaMqcD/uPuCsLZOjrtvDT93t7u3DJfHAee4\n+5VhougAjHX3EjO7B1ji7k+FpSFmElRqvS9s8+mwjEy6u5fFKHLYdLkqDdVwgrIl/Qnq/FTmOOCC\ncPlJgi/5eFzi7p/P+mdm3waOBWYFJXtoxhfF+S4Ky7c3IkgsAwjKORyOF9y9JFw+naB4283hegaQ\nTXC19GML5g54yd2XH2YbIl+iBCENStg99DhBdcutBH38Ftb4Py6Bf1EbMMndvzQYbma9gZuBEe6+\nPbwayKjkM8pfzlfcZ0+Fti50908q7LM07IY6G3jDzL7r7u8f5r9D5HMag5AGxd3neVDPv2x6zfeB\nM8LB5FjJ4WOCCpkAlwCxBqTjMRUYZ2ad4PP5qHsCrQm+3HeYWWeCuT7K7CKYCrTMJjM72szSCKqs\nVuZt4IawuihmNix87QOscvffE1T2HFzNf4sIoAQhDZCZdQS2ezDfQX93X1LF7jcAV4WD2pcBN1an\nzbCNOwhmBFxAMEtalrvPJ6gQuoxgPuF/ljtsIvBW2SA1wbzVrxMkrY1U7i6C6TAXmNnicB3gImBR\neLU0CHiiOv8WkTIapBYRkZh0BSEiIjEpQYiISExKECIiEpMShIiIxKQEISIiMSlBiIhITEoQIiIS\n0/8HO8OkJXHLeHQAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-0S7muQJ5L4R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tQCadc-j5MWG",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UFXLESNSOsHP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Utility: wrap a pandas DataFrame as a batched tf.data.Dataset.\n",
        "def df_to_dataset(dataframe, shuffle=True, batch_size=64):\n",
        "  \"\"\"Return a tf.data.Dataset of (feature_dict, label) batches.\n",
        "\n",
        "  The 'label' column is split off as the target. The caller's frame is\n",
        "  not mutated because we work on a copy.\n",
        "  \"\"\"\n",
        "  features = dataframe.copy()\n",
        "  target = features.pop('label')\n",
        "  dataset = tf.data.Dataset.from_tensor_slices((dict(features), target))\n",
        "  if shuffle:\n",
        "    dataset = dataset.shuffle(buffer_size=len(features))\n",
        "  return dataset.batch(batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uV3S2Ywbq6Mr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# batch_size is a tuning knob: lower it if the runtime runs out of RAM.\n",
        "batch_size = 10000\n",
        "\n",
        "# Wrap the pre-split frames (df_train / df_val / df_test, built earlier in\n",
        "# the notebook) as tf.data pipelines; eval splits keep their row order.\n",
        "train_ds = df_to_dataset(df_train, batch_size=batch_size)\n",
        "val_ds = df_to_dataset(df_val, shuffle=False, batch_size=batch_size)\n",
        "test_ds = df_to_dataset(df_test, shuffle=False, batch_size=batch_size)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eTUjgq25q6Qz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Build the Keras input layer: every column of df except the trailing\n",
        "# 'label' column becomes a numeric feature column.\n",
        "feature_columns = [feature_column.numeric_column(name)\n",
        "                   for name in list(df.columns)[:-1]]\n",
        "feature_layer = tf.keras.layers.DenseFeatures(feature_columns)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z4oChvXgq6Us",
        "colab_type": "code",
        "outputId": "fed94c32-0cb7-4598-c1f6-d2064fe01b7d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "# Shallow feed-forward classifier: dense-features input, one 128-unit\n",
        "# ReLU hidden layer, and a 5-way softmax output (one unit per class).\n",
        "model = tf.keras.Sequential([\n",
        "    feature_layer,\n",
        "    layers.Dense(128, activation='relu'),\n",
        "    layers.Dense(5, activation='softmax'),\n",
        "])\n",
        "\n",
        "# Integer class labels -> sparse categorical cross-entropy.\n",
        "model.compile(optimizer='adamax',\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "model.fit(train_ds,\n",
        "          validation_data=val_ds,\n",
        "          epochs=1,\n",
        "          workers=10)"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Train on 245 steps, validate on 123 steps\n",
            "245/245 [==============================] - 213s 869ms/step - loss: 200.2837 - acc: 0.9473 - val_loss: 27.9690 - val_acc: 0.9858\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7efb9a39ee48>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gdKM5YdarNxc",
        "colab_type": "code",
        "outputId": "5a97834b-52ee-465c-9f4c-53e1f5a4d620",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "# Evaluate the trained model on the held-out test split.\n",
        "# model.evaluate returns (loss, accuracy) per the compile() metrics.\n",
        "loss, accuracy = model.evaluate(test_ds)\n",
        "print(\"Accuracy on heldout test dataset: \", accuracy)\n"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "123/123 [==============================] - 45s 362ms/step - loss: 93.9236 - acc: 0.9857\n",
            "Accuracy on heldout test dataset:  0.9857244\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P1iav98K0zdT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Ground-truth labels -- assumes 'label' is the last column of df_test\n",
        "# (TODO confirm). Row order matches the predictions because test_ds was\n",
        "# built with shuffle=False.\n",
        "y_true = df_test.iloc[:,-1].astype('int')\n",
        "# y_pred is a (n_samples, 5) array of softmax scores.\n",
        "y_pred=model.predict(test_ds)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4gbyrvy1GOQy",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "outputId": "5021616a-0608-4f7d-a7df-023c53f8b78f"
      },
      "source": [
        "from sklearn.metrics import classification_report,confusion_matrix\n",
        "# Rows = true class, columns = predicted class (argmax over softmax scores).\n",
        "confusion_matrix(y_true, y_pred.argmax(axis=1))"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[231763,   5530,    865,   4982,     55],\n",
              "       [  1050, 968856,    352,    585,      0],\n",
              "       [  2575,   1214,   6486,      1,      0],\n",
              "       [   255,      3,      0,     21,      2],\n",
              "       [    11,      1,      1,      0,      0]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W-g18Rk2Hjpe",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 275
        },
        "outputId": "f7cf4f71-b783-4720-f218-cd989de097c1"
      },
      "source": [
        "# Per-class precision/recall/F1. NOTE(review): the recorded output shows an\n",
        "# UndefinedMetricWarning -- some minority classes apparently got no predicted\n",
        "# samples in that run, so their precision/F1 are reported as 0.0.\n",
        "print(classification_report(y_true, y_pred.argmax(axis=1)))"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/metrics/classification.py:1437: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
            "  'precision', 'predicted', average, warn_for)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.98      0.99      0.99    243195\n",
            "           1       1.00      1.00      1.00    970843\n",
            "           2       0.89      0.62      0.73     10276\n",
            "           3       0.00      0.00      0.00       281\n",
            "           4       0.00      0.00      0.00        13\n",
            "\n",
            "    accuracy                           0.99   1224608\n",
            "   macro avg       0.57      0.52      0.54   1224608\n",
            "weighted avg       0.99      0.99      0.99   1224608\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}